# ---> Neuro-inspired Sparse Training (NIST) manuscript <---
# Authors: Mohsen Kamelian Rad, Sotiris Moschoyiannis, Lu Yin, Roman Bauer---
# This script contains the training loop; edit it to change the model head, the final sparsity ratio, and the dataset directories.
# To change the sparsity ratio, edit the pruning_percentage passed to the "magnitude_prune_and_mask" function, which is defined and called at the end of the custom training loop (currently disabled inside a triple-quoted string in the trial loop).
# Note that mask_index and mask_step must be adjusted to match the linear (classifier) head used with VGG (see the comments on these parameters).
# To train the original model, set base_model.trainable = False in setup.py; for sparse training, freeze all convolutional blocks except the last one.

import tensorflow as tf
import numpy as np
from tensorflow.keras import datasets, layers
from nist import nist_layer
from setup import *
from load_dataset import get_dataset

import os
import gc

# Prepare the runtime environment before any dataset or model is created.
# (setup_environment is provided by the project's `setup` module via `from setup import *`.)
setup_environment()

# Root of the ImageNet-100 copy. Defaults to the original cluster location; set the
# NIST_IMAGENET100_ROOT environment variable to run elsewhere without editing this file.
_IMAGENET100_ROOT = os.environ.get(
    "NIST_IMAGENET100_ROOT",
    "/mnt/fast/nobackup/users/mk02339/nist/imagenet100",
)
# The training split is stored in four parts: train.X1 .. train.X4.
train_dirs = [
    os.path.join(_IMAGENET100_ROOT, f"train.X{part}") for part in range(1, 5)
]
val_dir = os.path.join(_IMAGENET100_ROOT, "val.X")
train_ds, val_ds = get_dataset(train_dirs, val_dir,
                               batch_size=64, shuffle=5000,
                               pref=2)

# Accuracy histories collected across trials. The "*2" pair is only written by the
# disabled custom loop (for trials with i >= 5); the live model.fit path fills the first pair.
trainin_accs, validation_accs = [], []
trainin_accs2, validation_accs2 = [], []

# Reset Keras global state before the first model is built.
tf.keras.backend.clear_session()

# Number of independent training runs performed by the loop below.
trials = 5
# Run `trials` independent training runs. Each run builds a fresh model with the chosen
# classifier head, trains it with model.fit, records the per-epoch train/validation
# accuracy history, checkpoints the histories to .npy files, and releases the model.
for i in range(trials):
    # With trials = 5 this branch is always taken; the else branch (empty head) is only
    # reached if `trials` is raised above 5.
    if i < 5:
        # Define the desired classifier head (by default VGG16's original setting: 2x4096 hidden layers).
        # The last number inside each tuple is the neuroseed factor for that layer.
        trial_fc_heads= [('nist_layer', 4096, 1), 'dropout', ('nist_layer', 4096, 2), 'dropout'
            ]
    else:
        trial_fc_heads = [
            ]
    model = build_model(trial_fc_heads)
    model.summary()
    # mask_index / mask_step are only read by the disabled custom training loop kept in the
    # string literal further down (mask_idx = mask_index; mask_idx += mask_step).
    mask_index = 3   # Adjust according to the model: index of the first Dense layer in the sequential model
    mask_step = 2    # '2' if there are dropout layers between Dense layers, '1' if Dense layers are consecutive
    # Disabled: alternative per-trial optimizer selection. This bare string literal is not executed.
    '''
    if i > 100:
        optimizer= tf.keras.optimizers.SGD(learning_rate=0.0001, momentum =0.9)
        print('Testing with SGD')
    elif i == 2:
        optimizer =tf.keras.optimizers.SGD(learning_rate=0.0001, momentum =0.9)
    else:
        optimizer= tf.keras.optimizers.Adam(learning_rate=0.0001)
        print('Switched to Adam')'''
    optimizer= tf.keras.optimizers.AdamW(learning_rate=0.0001)
    # from_logits=True assumes build_model's final layer emits raw logits (no softmax) -- TODO confirm in setup.py.
    loss_fn= tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    # Disabled: custom sparse-training loop (mask-restricted gradients + global magnitude pruning
    # via magnitude_prune_and_mask). The whole triple-quoted string below is a no-op expression;
    # the live path is model.compile / model.fit after it.
    # NOTE(review): if re-enabled, train_step is a tf.function that reads the Python globals
    # `masks` and `optimizer`; NumPy masks are likely captured as constants at trace time, so the
    # masks/optimizer rebound at the end of each epoch probably never reach the already-traced
    # graph -- verify before relying on the pruning schedule. Also, `var = var * mask` inside
    # train_step only rebinds a local name and has no effect.
    '''
    masks = []
    for layer in model.layers:
        if isinstance(layer, nist_layer):
            weights, biases = layer.get_weights()
            #masks.append(np.ones_like(weights))
            masks.append(layer.w)
            layer.raw_kernel.assign(layer.raw_kernel * layer.w)
            print('1\n')
        else:
            masks.append(None)
            print('0\n')
    # Pruning Function_ generating masks
    def magnitude_prune_and_mask(model, pruning_percentage):
        # Step 1: Collect all weights across nist layers
        all_weights =[]
        layer_weight_shape = []
        for layer in model.layers:
            if isinstance(layer, nist_layer):
                weights, _ = layer.get_weights()
                all_weights.append(np.abs(weights).flatten())
                layer_weight_shape.append(weights.shape)
        # Step 2: Concatenate all weights into a single array
        all_weights_flat = np.concatenate(all_weights)
        # Step 3: Compute global threshold
        k = int(pruning_percentage * all_weights_flat.size)
        if k == 0:
            print("Pruning percentage too low — no weights will be pruned.")
            threshold = -np.inf  # effectively disables pruning
        else:
            threshold = np.partition(all_weights_flat, k)[k]
        print(f"Global pruning threshold: {threshold:.4f}")
        # Step 4: Generate new masks and apply pruned weights
        new_masks = []
        #weight_idx = 0  # track position in all_weights_flat
        for layer in model.layers:
            if isinstance(layer, nist_layer):
                weights, biases = layer.get_weights()
                mask = (np.abs(weights) >= threshold).astype(np.float32)
                pruned_weights = weights * mask
                layer.set_weights([pruned_weights, biases])
                new_masks.append(mask)
                print('1\n')
            else:
                new_masks.append(None)
                print('0\n')
        return new_masks
    # Metrics
    train_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()
    val_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()
    @tf.function
    def train_step(x_batch, y_batch):
        with tf.GradientTape() as tape:
            logits = model(x_batch, training=True)
            loss = loss_fn(y_batch, logits)
        grads = tape.gradient(loss, model.trainable_weights)
        # Apply masks to gradients
        masked_grads = []
        mask_idx = mask_index
        for grad, var in zip(grads, model.trainable_weights):
            if 'raw_kernel' in var.name:
                mask = masks[mask_idx]
                #print(var.shape)
                grad = grad * mask
                var = var * mask
                mask_idx += mask_step
            masked_grads.append(grad)
        optimizer.apply_gradients(zip(masked_grads, model.trainable_weights))
        # Re-mask weights after update
        mask_idx = mask_index
        for layer in model.layers:
            if isinstance(layer, nist_layer):
                layer.raw_kernel.assign(layer.raw_kernel * masks[mask_idx])
                mask_idx += mask_step
        train_acc_metric.update_state(y_batch, logits)
        return loss
    # Validation step
    @tf.function
    def val_step(x, y):
        val_logits = model(x, training=False)
        val_acc_metric.update_state(y, val_logits)
    # Run training loop
    epochs = 60
    with tf.device('/GPU:0'):
        for epoch in range(epochs):
            print(f"\nEpoch {epoch + 1}/{epochs}")
            epoch_loss = 0.0
            batch_count = 0
            for x_batch, y_batch in train_ds:
                # print(f"x_batch shape: {x_batch.shape}, y_batch shape: {y_batch.shape}")
                loss = train_step(x_batch, y_batch)
                epoch_loss += loss.numpy()
                batch_count += 1
            avg_loss = epoch_loss / batch_count
            train_acc = train_acc_metric.result().numpy()
            # Run validation
            for x_batch_val, y_batch_val in val_ds:
                val_step(x_batch_val, y_batch_val)
            val_acc = val_acc_metric.result().numpy()
            if i < 5:
                trainin_accs.append(train_acc)
                validation_accs.append(val_acc)
            else:
                trainin_accs2.append(train_acc)
                validation_accs2.append(val_acc)
            np.save('train_accs_checkpoint.npy', trainin_accs)
            np.save('val_accs_checkpoint.npy', validation_accs)
            np.save('train_accs2_checkpoint.npy', trainin_accs2)
            np.save('val_accs2_checkpoint.npy', validation_accs2)
            #print(psutil.virtual_memory())
            print(f"Loss: {avg_loss:.4f}, Train Accuracy: {train_acc:.4f}, Validation Accuracy: {val_acc:.4f}")
            # Reset metrics for next epoch
            train_acc_metric.reset_state()
            val_acc_metric.reset_state()
            # Apply pruning after 5 epochs
            if epoch >= 5:
                print("Applying magnitude-based pruning...")
                masks = magnitude_prune_and_mask(model, pruning_percentage= 0.999)            # modify this to change the final sparsity
            if epoch > 4:
                #optimizer= tf.keras.optimizers.SGD(learning_rate=0.0001, momentum= 0.9)
                optimizer = tf.keras.optimizers.AdamW(learning_rate=0.00001)
    '''
    # Live path: standard Keras training. This script applies no weight masks on this path.
    model.compile(optimizer=optimizer,
        loss=loss_fn,
        metrics=['accuracy'])
    # model.summary()
    # 60 epochs, each running 1000 steps (batches) of train_ds; val_ds is evaluated after every epoch.
    history = model.fit(train_ds, epochs= 60, validation_data=val_ds, steps_per_epoch=1000)
    # history.history[...] holds one value per epoch, so one list per trial is appended below.
    train_acc = history.history['accuracy']           # or 'acc' in older versions
    val_acc = history.history['val_accuracy']
    trainin_accs.append(train_acc)
    validation_accs.append(val_acc)
    # Checkpoint after every trial so completed runs survive a crash or preemption.
    np.save('train_accs_checkpoint.npy', trainin_accs)
    np.save('val_accs_checkpoint.npy', validation_accs)
    # Release the model and Keras global state before the next trial to limit memory growth.
    del model
    tf.keras.backend.clear_session()
    gc.collect()
def _stack_trial_curves(curves, n_trials):
    """Stack recorded accuracy curves into an ``(n_trials, n_epochs)`` array.

    ``curves`` may be one list of per-epoch values per trial (what the model.fit
    path appends) or a flat list of ``n_trials * n_epochs`` scalars (what the
    disabled custom loop appends); the epoch axis is inferred either way instead
    of being hard-coded.

    Returns ``None`` when nothing was recorded. Previously the unused "*2" lists
    were reshaped to a fixed (5, 60), which raised ValueError on the empty list
    and aborted the script before any final file was saved.
    """
    if not curves:
        return None
    return np.reshape(np.array(curves), (n_trials, -1))


tr1 = _stack_trial_curves(trainin_accs, trials)
val1 = _stack_trial_curves(validation_accs, trials)
tr2 = _stack_trial_curves(trainin_accs2, trials)
val2 = _stack_trial_curves(validation_accs2, trials)

# Save every group that actually has data; report (rather than crash on) empty groups.
for _path, _curves in (
    ('train_accs_1.npy', tr1),
    ('Val_accs_1.npy', val1),
    ('train_accs_2.npy', tr2),
    ('Val_accs_2.npy', val2),
):
    if _curves is None:
        print(f"No accuracies recorded for {_path}; not saving it.")
        continue
    np.save(_path, _curves)

